import contextlib
import gc
import json
import sys
from dataclasses import asdict
from functools import partial
from os import remove
from pathlib import Path
from typing import Dict, Literal, Optional, Tuple, Union

import torch

# support running without installing as a package
# ruff: noqa: E402
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from lit_gpt.config_base import ConfigBase as Config
from lit_gpt.utils_old import NotYetLoadedTensor, incremental_save, lazy_load
# from scripts.convert_hf_checkpoint import layer_template, load_param

from bin_to_safetensors import convert_file


def layer_template(layer_name: str, idx: int) -> Tuple[str, int]:
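    """Splits a dotted weight name into a template and a layer index.

    For example (illustrative):
        layer_template('transformer.h.11.attn.attn.weight', 2)
        -> ('transformer.h.{}.attn.attn.weight', 11)
    """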
    split = layer_name.split('.')
    number = int(split[idx])
    split[idx] = '{}'
    from_name = '.'.join(split)
    return from_name, number


def load_param(
    param: Union[torch.Tensor, NotYetLoadedTensor],
    name: str,
    dtype: Optional[torch.dtype],
) -> torch.Tensor:
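    """Materializes a lazily-loaded tensor if needed, then optionally casts it to `dtype`."""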
    if hasattr(param, '_load_tensor'):
        # support tensors loaded via `lazy_load()`
        print(f'Loading {name!r} into RAM')
        param = param._load_tensor()
    if (
        dtype is not None
        and type(dtype) is not NotYetLoadedTensor
        and dtype != param.dtype
    ):
        print(f'Converting {name!r} from {param.dtype} to {dtype}')
        param = param.to(dtype)
    return param


def copy_weights_falcon(
    size: Literal['7b', '40b'],
    state_dict: Dict[str, torch.Tensor],
    lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
    saver: Optional[incremental_save] = None,
) -> None:
    weight_map = {
        'transformer.wte.weight': 'transformer.word_embeddings.weight',
        'transformer.h.{}.attn.attn.weight': 'transformer.h.{}.self_attention.query_key_value.weight',
        'transformer.h.{}.attn.proj.weight': 'transformer.h.{}.self_attention.dense.weight',
        'transformer.h.{}.mlp.fc.weight': 'transformer.h.{}.mlp.dense_h_to_4h.weight',
        'transformer.h.{}.mlp.proj.weight': 'transformer.h.{}.mlp.dense_4h_to_h.weight',
        'transformer.ln_f.bias': 'transformer.ln_f.bias',
        'transformer.ln_f.weight': 'transformer.ln_f.weight',
        'lm_head.weight': 'lm_head.weight',
    }
    # the original model definition is different for each size
    if size == '7b':
        weight_map.update(
            {
                'transformer.h.{}.norm_1.bias': 'transformer.h.{}.input_layernorm.bias',
                'transformer.h.{}.norm_1.weight': 'transformer.h.{}.input_layernorm.weight',
            }
        )
    elif size == '40b':
        weight_map.update(
            {
                'transformer.h.{}.norm_1.bias': 'transformer.h.{}.ln_attn.bias',
                'transformer.h.{}.norm_1.weight': 'transformer.h.{}.ln_attn.weight',
                'transformer.h.{}.norm_2.bias': 'transformer.h.{}.ln_mlp.bias',
                'transformer.h.{}.norm_2.weight': 'transformer.h.{}.ln_mlp.weight',
            }
        )
    else:
        raise NotImplementedError(f'Unsupported falcon size: {size!r}')

    for name, param in lit_weights.items():
        if 'transformer.h' in name:
            from_name, number = layer_template(name, 2)
            to_name = weight_map[from_name].format(number)
        else:
            to_name = weight_map[name]
        param = load_param(param, name, None)
        if saver is not None:
            param = saver.store_early(param)
        state_dict[to_name] = param


def copy_weights_gpt_neox(
    state_dict: Dict[str, torch.Tensor],
    lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
    saver: Optional[incremental_save] = None,
) -> None:
    weight_map = {
        'transformer.wte.weight': 'gpt_neox.embed_in.weight',
        'transformer.h.{}.norm_1.bias': 'gpt_neox.layers.{}.input_layernorm.bias',
        'transformer.h.{}.norm_1.weight': 'gpt_neox.layers.{}.input_layernorm.weight',
        'transformer.h.{}.attn.attn.bias': 'gpt_neox.layers.{}.attention.query_key_value.bias',
        'transformer.h.{}.attn.attn.weight': 'gpt_neox.layers.{}.attention.query_key_value.weight',
        'transformer.h.{}.attn.proj.bias': 'gpt_neox.layers.{}.attention.dense.bias',
        'transformer.h.{}.attn.proj.weight': 'gpt_neox.layers.{}.attention.dense.weight',
        'transformer.h.{}.norm_2.bias': 'gpt_neox.layers.{}.post_attention_layernorm.bias',
        'transformer.h.{}.norm_2.weight': 'gpt_neox.layers.{}.post_attention_layernorm.weight',
        'transformer.h.{}.mlp.fc.bias': 'gpt_neox.layers.{}.mlp.dense_h_to_4h.bias',
        'transformer.h.{}.mlp.fc.weight': 'gpt_neox.layers.{}.mlp.dense_h_to_4h.weight',
        'transformer.h.{}.mlp.proj.bias': 'gpt_neox.layers.{}.mlp.dense_4h_to_h.bias',
        'transformer.h.{}.mlp.proj.weight': 'gpt_neox.layers.{}.mlp.dense_4h_to_h.weight',
        'transformer.ln_f.bias': 'gpt_neox.final_layer_norm.bias',
        'transformer.ln_f.weight': 'gpt_neox.final_layer_norm.weight',
        'lm_head.weight': 'embed_out.weight',
    }

    for name, param in lit_weights.items():
        if 'transformer.h' in name:
            from_name, number = layer_template(name, 2)
            to_name = weight_map[from_name].format(number)
        else:
            to_name = weight_map[name]
        param = load_param(param, name, None)
        if saver is not None:
            param = saver.store_early(param)
        state_dict[to_name] = param


def copy_weights_llama(
    config: Config,
    state_dict: Dict[str, torch.Tensor],
    lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
    saver: Optional[incremental_save] = None,
) -> None:
    weight_map = {
        'transformer.wte.weight': 'model.embed_tokens.weight',
        'transformer.h.{}.norm_1.weight': 'model.layers.{}.input_layernorm.weight',
        'transformer.h.{}.attn.proj.weight': 'model.layers.{}.self_attn.o_proj.weight',
        'transformer.h.{}.norm_2.weight': 'model.layers.{}.post_attention_layernorm.weight',
        'transformer.h.{}.mlp.swiglu.w1.weight': 'model.layers.{}.mlp.gate_proj.weight',
        'transformer.h.{}.mlp.swiglu.w2.weight': 'model.layers.{}.mlp.up_proj.weight',
        'transformer.h.{}.mlp.swiglu.w3.weight': 'model.layers.{}.mlp.down_proj.weight',
        'transformer.ln_f.weight': 'model.norm.weight',
        'lm_head.weight': 'lm_head.weight',
    }
    for name, param in lit_weights.items():
        if name.endswith('.attn.attn.weight'):
            from_name, number = layer_template(name, 2)
            q = 'model.layers.{}.self_attn.q_proj.weight'.format(number)
            k = 'model.layers.{}.self_attn.k_proj.weight'.format(number)
            v = 'model.layers.{}.self_attn.v_proj.weight'.format(number)
            qkv = load_param(param, name, None)
            qp, kp, vp = tensor_split(qkv, config)
            for to_name, param in zip((q, k, v), (qp, kp, vp)):
                if saver is not None:
                    param = saver.store_early(param)
                state_dict[to_name] = param
        elif 'transformer.h' in name:
            from_name, number = layer_template(name, 2)
            to_name = weight_map[from_name].format(number)
            param = load_param(param, name, None)
            if saver is not None:
                param = saver.store_early(param)
            state_dict[to_name] = param
        else:
            to_name = weight_map[name]
            param = load_param(param, name, None)
            if saver is not None:
                param = saver.store_early(param)
            state_dict[to_name] = param


# Updated from the litgpt repo (10 Mar 2024); supports MoE.
# TODO: update this for the xformers SwiGLU variant, whose MLP layers live
# under `transformer.h.{}.mlp.swiglu` (handled by `copy_weights_llama` above).
def copy_weights_llama_2(
    config: Config,
    state_dict: Dict[str, torch.Tensor],
    lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
    untie_weights: bool = False,
    saver: Optional[incremental_save] = None,
) -> None:
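    """Copies lit weights into `state_dict` under HF llama parameter names.

    The `{l}` and `{e}` placeholders in the `weight_map` values are the layer
    and expert indices respectively; the bare `{}` in the keys matches the
    templates produced by `layer_template`.
    """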
    weight_map = {
        'transformer.wte.weight': 'model.embed_tokens.weight',
        'transformer.h.{}.norm_1.weight': 'model.layers.{l}.input_layernorm.weight',
        'transformer.h.{}.norm_1.bias': 'model.layers.{l}.input_layernorm.bias',
        'transformer.h.{}.attn.proj.weight': 'model.layers.{l}.self_attn.o_proj.weight',
        'transformer.h.{}.norm_2.weight': 'model.layers.{l}.post_attention_layernorm.weight',
        'transformer.h.{}.norm_2.bias': 'model.layers.{l}.post_attention_layernorm.bias',
        'transformer.ln_f.weight': 'model.norm.weight',
        'transformer.ln_f.bias': 'model.norm.bias',
        'lm_head.weight': 'lm_head.weight',
    }
    if config._mlp_class == 'LLaMAMoE':
        weight_map.update(
            {
                'transformer.h.{}.mlp.gate.weight': 'model.layers.{l}.block_sparse_moe.gate.weight',
                'transformer.h.{}.mlp.experts.{}.fc_1.weight': 'model.layers.{l}.block_sparse_moe.experts.{e}.w1.weight',
                'transformer.h.{}.mlp.experts.{}.fc_2.weight': 'model.layers.{l}.block_sparse_moe.experts.{e}.w3.weight',
                'transformer.h.{}.mlp.experts.{}.proj.weight': 'model.layers.{l}.block_sparse_moe.experts.{e}.w2.weight',
            }
        )
    elif config._mlp_class in ('LLaMAMLP', 'GemmaMLP'):
        weight_map.update(
            {
                'transformer.h.{}.mlp.fc_1.weight': 'model.layers.{l}.mlp.gate_proj.weight',
                'transformer.h.{}.mlp.fc_2.weight': 'model.layers.{l}.mlp.up_proj.weight',
                'transformer.h.{}.mlp.proj.weight': 'model.layers.{l}.mlp.down_proj.weight',
            }
        )
    else:
        raise NotImplementedError(f'Unsupported _mlp_class: {config._mlp_class!r}')

    for name, param in lit_weights.items():
        if name == 'lm_head.weight' and untie_weights:
            continue
        if name.endswith('.attn.attn.weight'):
            from_name, l = layer_template(name, 2)
            q = 'model.layers.{}.self_attn.q_proj.weight'.format(l)
            k = 'model.layers.{}.self_attn.k_proj.weight'.format(l)
            v = 'model.layers.{}.self_attn.v_proj.weight'.format(l)
            qkv = load_param(param, name, None)
            qp, kp, vp = qkv_split(qkv, config)
            for to_name, param in zip((q, k, v), (qp, kp, vp)):
                if saver is not None:
                    param = saver.store_early(param)
                state_dict[to_name] = param
        else:
            if 'transformer.h' in name:
                from_name, l = layer_template(name, 2)
                e = None
                if 'mlp.experts' in name:
                    from_name, e = layer_template(from_name, 5)
                to_name = weight_map[from_name]
                to_name = to_name.format(l=l, e=e)
            else:
                to_name = weight_map[name]
            param = load_param(param, name, None)
            if saver is not None:
                param = saver.store_early(param)
            state_dict[to_name] = param


# Updated from the litgpt repo (10 Mar 2024).
def qkv_split(
    param: Union[torch.Tensor, NotYetLoadedTensor], config: Config
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
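    """Splits a fused QKV weight into separate Q, K and V weights.

    Assumes the lit-gpt layout where, for each of the `n_query_groups` groups,
    the rows are ordered as: `q_per_kv * head_size` query rows, `head_size` key
    rows, then `head_size` value rows.
    """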
    q_per_kv = config.n_head // config.n_query_groups
    qs = []
    ks = []
    vs = []
    for chunk in torch.chunk(param, config.n_query_groups):
        split = torch.split(
            chunk, [config.head_size * q_per_kv, config.head_size, config.head_size]
        )
        qs.append(split[0])
        ks.append(split[1])
        vs.append(split[2])
    q = torch.cat(qs)
    k = torch.cat(ks)
    v = torch.cat(vs)
    return q, k, v


def tensor_split(
    param: Union[torch.Tensor, NotYetLoadedTensor], config: Config
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
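    """Splits a fused QKV weight into Q, K and V by index arithmetic.

    This should be equivalent to `qkv_split` above: each group of `blen` rows
    is sliced into its query rows, `head_size` key rows and `head_size` value
    rows.
    """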
    def kstart(start, blen, klen) -> int:
        """Returns the starting row of the keys within a group."""
        return start + (blen - (klen * 2))

    def vstart(start, blen, vlen) -> int:
        """Returns the starting row of the values within a group."""
        return start + blen - vlen

    def vend(start, blen) -> int:
        """Returns the end (exclusive) row of the values within a group."""
        return start + blen

    # total number of rows in the fused QKV weight
    nobs = param.shape[0]
    # rows per query group: the query rows followed by the key and value rows
    blen = nobs // config.n_query_groups
    # number of key rows in each group
    klen = config.head_size
    # number of value rows in each group
    vlen = config.head_size
    # the starting row of each group
    starts = range(0, nobs, blen)
    # the (q_start, k_start, v_start, v_end) row indices to slice on
    splices = [
        (s, kstart(s, blen, klen), vstart(s, blen, vlen), vend(s, blen)) for s in starts
    ]

    qc = ()
    kc = ()
    vc = ()

    for splice in splices:
        qs, ks, vs, ve = splice
        qc += (param[qs:ks, :],)
        kc += (param[ks:vs, :],)
        vc += (param[vs:ve, :],)

    q = torch.cat(qc)
    k = torch.cat(kc)
    v = torch.cat(vc)

    return q, k, v


def maybe_unwrap_state_dict(
    lit_weights: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
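    """Returns the nested weights if the checkpoint wraps them under a 'model' key."""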
    return lit_weights.get('model', lit_weights)


def check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None:
    weight_names = {wk.split('.')[-1] for wk in lit_weights}
    # LoRA or QLoRA
    if any('lora' in wn for wn in weight_names):
        raise ValueError(
            'Model weights must be merged using `lora.merge_lora_weights()` before conversion.'
        )
    # adapter v2. adapter_bias will only be in adapter_v2
    elif 'adapter_bias' in weight_names:
        raise NotImplementedError(
            'Converting models finetuned with adapter_v2 not yet supported.'
        )
    # adapter. gating_factor is in adapter and adapter_v2
    elif 'gating_factor' in weight_names:
        raise NotImplementedError(
            'Converting models finetuned with adapter not yet supported.'
        )


def get_pints_init_hf_config() -> dict:
    return {
        'architectures': ['LlamaForCausalLM'],
        'bos_token_id': 1,
        'eos_token_id': 2,
        'pad_token_id': 32001,
        'hidden_act': 'silu',
        'initializer_range': 0.02,
        'model_type': 'llama',
        'pretraining_tp': 1,
        'rope_scaling': None,
        'tie_word_embeddings': False,
        'torch_dtype': 'float32',
        'transformers_version': '4.38.0',
        'use_cache': True,
    }


def convert_config_lit_to_hf(lit_config_dict: dict) -> dict:
    lit_hf_mapping = {
        'block_size': 'max_position_embeddings',
        'vocab_size': 'vocab_size',
        'n_layer': 'num_hidden_layers',
        'n_embd': 'hidden_size',
        'n_head': 'num_attention_heads',
        'n_query_groups': 'num_key_value_heads',
        'intermediate_size': 'intermediate_size',
        'norm_eps': 'rms_norm_eps',
    }
    print(
        'CONFIG IS HARDCODED FOR PINTS MODEL. You will need to edit the config.json output for other models.'
    )
    hf_config_dict = get_pints_init_hf_config()

    for lit_key, hf_key in lit_hf_mapping.items():
        hf_config_dict[hf_key] = lit_config_dict[lit_key]
    return hf_config_dict
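

# For example (illustrative values): a lit config with n_layer=24 and
# n_embd=2048 yields "num_hidden_layers": 24 and "hidden_size": 2048 in the
# resulting config.json.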


@torch.inference_mode()
def convert_lit_checkpoint(
    *,
    checkpoint_name: str,
    directory: Path,
    output_directory: Path,
    model_name: str,
    output_config=True,
    safetensors=True,
    # Safetensors is the better format, but when pytorch_model.bin is also
    # present, the .bin always gets loaded. It is safe to delete the .bin
    # except in special cases where it is needed.
    delete_pytorch_model=True,
) -> None:
    """
    Converts lit checkpoint to pytorch model and safetensors

    Args:
        checkpoint_name: Filename of the checkpoint
        directory: Directory where the checkpoint resides. Is also the output directory.
        model_name: The `name` of the model as defined in `lit_gpt/config_base.py`. E.g, "0.12-Pint" or "1.5-Pints". This is to extract the model configuration necessary for conversion.
        output_config: Produce config.json.
        safetensors: Output safetensors. This is the better format over pytorch_model.
        delete_pytorch_model: If safetensors=True, we can delete pytorch_model.bin to force Huggingface to use safetensors.
    """
    config = Config.from_name(model_name)

    if 'falcon' in model_name:
        copy_fn = partial(copy_weights_falcon, '40b' if config.n_embd == 8192 else '7b')
    elif config._mlp_class == 'LLaMAMLP':
        copy_fn = partial(copy_weights_llama, config)
    else:
        copy_fn = copy_weights_gpt_neox

    # initialize a new empty state dict to hold our new weights
    sd = {}

    # checkpoint_name cannot be hardcoded because there are different possible
    # outputs, such as "lit_model_finetuned.pth", "lit_model_lora_finetuned.pth"
    # or "lit_model_adapter_finetuned.pth"
    pth_file = directory / checkpoint_name
    pth_file = pth_file.absolute().resolve()
    print('Reading checkpoint:', str(pth_file))

    bin_filename = 'pytorch_model.bin'
    bin_file = directory / bin_filename
    bin_file = bin_file.absolute().resolve()

    # If there is a pytorch_model already existing, check if user wants to overwrite.
    if bin_file.is_file():
        match input(
            f'WARNING: `{bin_filename}` already exists. Are you sure you want to overwrite? (yes/no): '
        ):
            case 'yes':
                pass
            case _:
                exit()

    with incremental_save(bin_file) as saver:
        with contextlib.ExitStack() as stack:
            lit_weights = stack.enter_context(lazy_load(pth_file))
            lit_weights = maybe_unwrap_state_dict(lit_weights)
            check_conversion_supported(lit_weights)
            # Passing the incremental saver to copy_fn triggers an error, so
            # accumulate the weights in memory and save them in one go below.
            copy_fn(sd, lit_weights, saver=None)
            gc.collect()
        saver.save(sd)

    print('Checkpoint converted to pytorch and saved to:', str(bin_file))

    # Also make safetensors
    if safetensors:
        print('Converting to safetensors...')
        safetensors_filename = 'model.safetensors'
        safetensors_file = output_directory / safetensors_filename

        if safetensors_file.is_file():
            match input(
                f'WARNING: `{safetensors_filename}` already exists. Are you sure you want to overwrite? (yes/no): '
            ):
                case 'yes':
                    pass
                case _:
                    exit()

        convert_file(bin_file, safetensors_file)
        print(
            'Pytorch model converted to safetensors and saved to:',
            str(safetensors_file),
        )
        if delete_pytorch_model:
            print('Deleting pytorch_model.bin...')
            remove(bin_file)
            print('pytorch_model.bin removed.')

    # convert lit config file to hf-style
    if output_config:
        print('Converting config file...')

        config_filename = 'config.json'
        config_path = output_directory / config_filename

        if config_path.is_file():
            match input(
                f'WARNING: `{config_filename}` already exists. Are you sure you want to overwrite? (yes/no): '
            ):
                case 'yes':
                    pass
                case _:
                    exit()

        lit_config = asdict(config)
        hf_config = convert_config_lit_to_hf(lit_config)
        with open(config_path, 'w') as f:
            json.dump(hf_config, f, indent=4)

    print('Done!')


if __name__ == '__main__':
    from jsonargparse import CLI

    CLI(convert_lit_checkpoint, as_positional=False)
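
    # Example invocation (illustrative paths; adjust the script path,
    # checkpoint and model name for your setup):
    #   python convert_lit_checkpoint.py \
    #       --checkpoint_name lit_model.pth \
    #       --directory /path/to/checkpoint_dir \
    #       --output_directory /path/to/output_dir \
    #       --model_name 1.5-Pints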
